{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    },
    "toc": {
      "nav_menu": {},
      "number_sections": true,
      "sideBar": true,
      "skip_h1_title": false,
      "title_cell": "Table of Contents",
      "title_sidebar": "Contents",
      "toc_cell": false,
      "toc_position": {},
      "toc_section_display": true,
      "toc_window_display": false
    },
    "colab": {
      "name": "gnn-basic-edge-1.ipynb",
      "provenance": []
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "0823b6e94bdb4b7c9a8738194bc4d09c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_dc736354081a4d28a99a071fefc53ee8",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_5c379bfdb6354aac87afff41381809d2",
              "IPY_MODEL_e0732bf6c30940f5b1c1260024647d2f"
            ]
          }
        },
        "dc736354081a4d28a99a071fefc53ee8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "5c379bfdb6354aac87afff41381809d2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_a98e75a5b77849bcae5917e99d05c772",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_494990f144ee4fd789a16d1cb756f985"
          }
        },
        "e0732bf6c30940f5b1c1260024647d2f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_0154358d3957437f8b866a2c171aaf1c",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 9920512/? [00:20&lt;00:00, 2913644.28it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_dc239abe8e2e48cf85854ca2bfcab31b"
          }
        },
        "a98e75a5b77849bcae5917e99d05c772": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "494990f144ee4fd789a16d1cb756f985": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "0154358d3957437f8b866a2c171aaf1c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "dc239abe8e2e48cf85854ca2bfcab31b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "c96ce66789cc408190a4b70d352f16ab": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_b5a87422351a4c0f90427364cc80b69e",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_4b111c8230a0468fba592f874c3ef6f6",
              "IPY_MODEL_9bb5cd6d77644d558a01c7c661b4c359"
            ]
          }
        },
        "b5a87422351a4c0f90427364cc80b69e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "4b111c8230a0468fba592f874c3ef6f6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_9dbb30e97f6544959fac7bf55b55cbb7",
            "_dom_classes": [],
            "description": "  0%",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 0,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_42b697e9795d4f1a8ba273aee20fade9"
          }
        },
        "9bb5cd6d77644d558a01c7c661b4c359": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_076602a8b58840b49f338216f26e176d",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 0/28881 [00:00&lt;?, ?it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_0aec7b08588f471d8934708a5f84decd"
          }
        },
        "9dbb30e97f6544959fac7bf55b55cbb7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "42b697e9795d4f1a8ba273aee20fade9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "076602a8b58840b49f338216f26e176d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "0aec7b08588f471d8934708a5f84decd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "c00e3ae7631948a4a3367172922f5c09": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_07e9947cb4c6444282ca2f2caffaa23c",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_9c43d492489b46b893014a628ca92ac7",
              "IPY_MODEL_531020a743284445aef3033937fc16ee"
            ]
          }
        },
        "07e9947cb4c6444282ca2f2caffaa23c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "9c43d492489b46b893014a628ca92ac7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_b4764f34a84f4f118da3928fbec24959",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_def1013653de4754ae0bb183bb9cc4e0"
          }
        },
        "531020a743284445aef3033937fc16ee": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_e9e9a516d1e64c21a0b0c48ea42518c2",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 1654784/? [00:00&lt;00:00, 2862067.05it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_5cbf629ce64247a9a48e327b8f817aed"
          }
        },
        "b4764f34a84f4f118da3928fbec24959": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "def1013653de4754ae0bb183bb9cc4e0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "e9e9a516d1e64c21a0b0c48ea42518c2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "5cbf629ce64247a9a48e327b8f817aed": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "e16159e7e9aa4fac9a35e130d7528d12": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_6af8235767034e4fa2e48e98b741216b",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_0fbcb8effc1d422e9639cadcb013d953",
              "IPY_MODEL_54690430902d4abfbfa8699c66e27cd0"
            ]
          }
        },
        "6af8235767034e4fa2e48e98b741216b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "0fbcb8effc1d422e9639cadcb013d953": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_52e1b9b521534e77844796b10f8ad1cc",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_98e735e76e3a4e0c8681ccfe4831c2ea"
          }
        },
        "54690430902d4abfbfa8699c66e27cd0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_b0febc37e00a43b7bb828203d741db31",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 8192/? [00:00&lt;00:00, 31989.00it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_243a31da935a45caa6939bd8e519aa3c"
          }
        },
        "52e1b9b521534e77844796b10f8ad1cc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "98e735e76e3a4e0c8681ccfe4831c2ea": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "b0febc37e00a43b7bb828203d741db31": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "243a31da935a45caa6939bd8e519aa3c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kxuIBwM_g0AZ",
        "colab_type": "text"
      },
      "source": [
        "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
        "- Author: Sebastian Raschka\n",
        "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XKblzChShlT-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install -q IPython\n",
        "!pip install -q ipykernel\n",
        "!pip install -q watermark\n",
        "!pip install -q matplotlib\n",
        "# The deprecated 'sklearn' PyPI name refuses to install; the real package is 'scikit-learn'.\n",
        "!pip install -q scikit-learn\n",
        "!pip install -q pandas\n",
        "!pip install -q pydot\n",
        "!pip install -q hiddenlayer\n",
        "!pip install -q graphviz"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d4gj8EPQg0Aa",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 121
        },
        "outputId": "1579528d-533a-4a5d-97ce-20009aebca2b"
      },
      "source": [
        "%load_ext watermark\n",
        "%watermark -a 'Sebastian Raschka' -v -p torch"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Sebastian Raschka \n",
            "\n",
            "CPython 3.6.9\n",
            "IPython 5.5.0\n",
            "\n",
            "torch 1.5.1+cu101\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hb3xbf_fg0Af",
        "colab_type": "text"
      },
      "source": [
        "- Runs on CPU or GPU (if available)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uk5jAF_Qg0Ag",
        "colab_type": "text"
      },
      "source": [
        "# Basic Graph Neural Network with Edge Prediction on MNIST"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iMNeGWi9g0Ag",
        "colab_type": "text"
      },
      "source": [
        "Implementing a very basic graph neural network (GNN) using a subnetwork for edge prediction. \n",
        "\n",
        "Here, the 28x28 image of a digit in MNIST represents the graph, where each pixel (i.e., cell in the grid) represents a particular node. The feature of that node is simply the pixel intensity in range [0, 1]. \n",
        "\n",
        "In the related notebook, [gnn-basic-1.ipynb], the adjacency matrix was determined by the neighborhood of each pixel: using a Gaussian filter, pixels were connected based on their Euclidean distance in the grid. In **this notebook**, the edge weights are instead predicted by a separate neural network model from the standardized coordinates of each pair of pixels:\n",
        "\n",
        "\n",
        "```python\n",
        "        self.pred_edge_fc = nn.Sequential(nn.Linear(coord_features, 32),\n",
        "                                          nn.ReLU(),\n",
        "                                          nn.Linear(32, 1),\n",
        "                                          nn.Tanh())\n",
        "```\n",
        "\n",
        "\n",
        "Using the resulting adjacency matrix $A$, we can compute the output of a layer as \n",
        "\n",
        "$$X^{(l+1)}=A X^{(l)} W^{(l)}.$$\n",
        "\n",
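        "Here, $A$ is the $N \\times N$ adjacency matrix and $X$ is the $N \\times C$ node-feature matrix, where $N$ is the total number of pixels ($28 \\times 28 = 784$ in MNIST) and $C = 1$ because each node's only feature is its pixel intensity; the 2D pixel coordinates serve only as input to the edge-prediction network. Since the model has a single layer, the aggregated features $AX$ (shape $N \\times 1$) are flattened and multiplied by a weight matrix $W$ of shape $N \\times P$, where $P$ is the number of classes.\n",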
        "\n",
        "\n",
        "- Inspired by and based on Boris Knyazev's tutorial at https://medium.com/@BorisAKnyazev/tutorial-on-graph-neural-networks-for-computer-vision-and-beyond-part-1-3d9fada3b80d."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jIO5LnEyg0Ah",
        "colab_type": "text"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GytGe5TYg0Ah",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import time\n",
        "import numpy as np\n",
        "from scipy.spatial.distance import cdist\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torchvision import datasets\n",
        "from torchvision import transforms\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.utils.data.dataset import Subset\n",
        "\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    torch.backends.cudnn.deterministic = True"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GJ0NLx7xg0Ak",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2GVeisNTg0An",
        "colab_type": "text"
      },
      "source": [
        "## Settings and Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1GE0l5kbg0An",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "##########################\n",
        "### SETTINGS\n",
        "##########################\n",
        "\n",
        "# Device: first CUDA GPU if one is available, otherwise the CPU\n",
        "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "# Hyperparameters\n",
        "RANDOM_SEED = 1\n",
        "LEARNING_RATE = 0.0005\n",
        "NUM_EPOCHS = 50\n",
        "BATCH_SIZE = 128\n",
        "\n",
        "# Data and architecture\n",
        "IMG_SIZE = 28     # MNIST images are 28x28 pixels\n",
        "NUM_CLASSES = 10  # digits 0-9"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2IbQPfmtg0Ar",
        "colab_type": "text"
      },
      "source": [
        "## MNIST Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qLtk6EACg0As",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 424,
          "referenced_widgets": [
            "0823b6e94bdb4b7c9a8738194bc4d09c",
            "dc736354081a4d28a99a071fefc53ee8",
            "5c379bfdb6354aac87afff41381809d2",
            "e0732bf6c30940f5b1c1260024647d2f",
            "a98e75a5b77849bcae5917e99d05c772",
            "494990f144ee4fd789a16d1cb756f985",
            "0154358d3957437f8b866a2c171aaf1c",
            "dc239abe8e2e48cf85854ca2bfcab31b",
            "c96ce66789cc408190a4b70d352f16ab",
            "b5a87422351a4c0f90427364cc80b69e",
            "4b111c8230a0468fba592f874c3ef6f6",
            "9bb5cd6d77644d558a01c7c661b4c359",
            "9dbb30e97f6544959fac7bf55b55cbb7",
            "42b697e9795d4f1a8ba273aee20fade9",
            "076602a8b58840b49f338216f26e176d",
            "0aec7b08588f471d8934708a5f84decd",
            "c00e3ae7631948a4a3367172922f5c09",
            "07e9947cb4c6444282ca2f2caffaa23c",
            "9c43d492489b46b893014a628ca92ac7",
            "531020a743284445aef3033937fc16ee",
            "b4764f34a84f4f118da3928fbec24959",
            "def1013653de4754ae0bb183bb9cc4e0",
            "e9e9a516d1e64c21a0b0c48ea42518c2",
            "5cbf629ce64247a9a48e327b8f817aed",
            "e16159e7e9aa4fac9a35e130d7528d12",
            "6af8235767034e4fa2e48e98b741216b",
            "0fbcb8effc1d422e9639cadcb013d953",
            "54690430902d4abfbfa8699c66e27cd0",
            "52e1b9b521534e77844796b10f8ad1cc",
            "98e735e76e3a4e0c8681ccfe4831c2ea",
            "b0febc37e00a43b7bb828203d741db31",
            "243a31da935a45caa6939bd8e519aa3c"
          ]
        },
        "outputId": "1333f38e-c056-41d3-bae6-d8a80c9f6101"
      },
      "source": [
        "custom_transform = transforms.Compose([transforms.ToTensor()])\n",
        "\n",
        "# The 60k MNIST training images are split: first 59k for training, last 1k for validation.\n",
        "train_indices = torch.arange(0, 59000)\n",
        "valid_indices = torch.arange(59000, 60000)\n",
        "\n",
        "train_and_valid = datasets.MNIST(root='data',\n",
        "                                 train=True,\n",
        "                                 transform=custom_transform,\n",
        "                                 download=True)\n",
        "test_dataset = datasets.MNIST(root='data',\n",
        "                              train=False,\n",
        "                              transform=custom_transform,\n",
        "                              download=True)\n",
        "\n",
        "train_dataset = Subset(train_and_valid, train_indices)\n",
        "valid_dataset = Subset(train_and_valid, valid_indices)\n",
        "\n",
        "# Only the training loader shuffles; validation and test batches keep a fixed order.\n",
        "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,\n",
        "                          shuffle=True, num_workers=4)\n",
        "valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE,\n",
        "                          shuffle=False, num_workers=4)\n",
        "test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,\n",
        "                         shuffle=False, num_workers=4)\n",
        "\n",
        "# Sanity check: shapes of a single training batch\n",
        "images, labels = next(iter(train_loader))\n",
        "print('Image batch dimensions:', images.shape)\n",
        "print('Image label dimensions:', labels.shape)"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0823b6e94bdb4b7c9a8738194bc4d09c",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c96ce66789cc408190a4b70d352f16ab",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c00e3ae7631948a4a3367172922f5c09",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e16159e7e9aa4fac9a35e130d7528d12",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw\n",
            "Processing...\n",
            "Done!\n",
            "\n",
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/pytorch/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Image batch dimensions: torch.Size([128, 1, 28, 28])\n",
            "Image label dimensions: torch.Size([128])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nVLcPyiXg0Au",
        "colab_type": "text"
      },
      "source": [
        "## Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GFicewkmg0Av",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "##########################\n",
        "### MODEL\n",
        "##########################\n",
        "\n",
        "\n",
        "def make_coordinate_array(img_size, out_size=4):\n",
        "    \"\"\"Build the pairwise node-coordinate tensor used by the edge predictor.\n",
        "\n",
        "    Every pixel of the img_size x img_size grid is a node whose 2D\n",
        "    (column, row) coordinate is standardized to zero mean and unit variance.\n",
        "    For each node pair (i, j), the last dimension holds the coordinate of\n",
        "    node j tiled (out_size // 2 - 1) times, followed by the coordinate of node i.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    img_size : int\n",
        "        Height and width of the square image (28 for MNIST).\n",
        "    out_size : int\n",
        "        Number of features per node pair; must be an even integer >= 2.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    torch.Tensor\n",
        "        Float tensor of shape [N, N, out_size], where N = img_size**2.\n",
        "\n",
        "    Raises\n",
        "    ------\n",
        "    ValueError\n",
        "        If out_size is odd or smaller than 2.\n",
        "    \"\"\"\n",
        "    if out_size < 2 or out_size % 2 != 0:\n",
        "        # Other values would silently produce a feature width != out_size and\n",
        "        # only fail later, inside the edge predictor's first Linear layer.\n",
        "        raise ValueError(f'out_size must be an even integer >= 2, got {out_size!r}')\n",
        "\n",
        "    ### Make 2D coordinate array (for MNIST: 784x2)\n",
        "    n_nodes = img_size * img_size\n",
        "    col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))\n",
        "    coord = np.stack((col, row), axis=2).reshape(-1, 2)\n",
        "    # The epsilon avoids a division by zero when the std is 0 (img_size == 1).\n",
        "    coord = (coord - np.mean(coord, axis=0)) / (np.std(coord, axis=0) + 1e-5)\n",
        "    coord = torch.from_numpy(coord).float()\n",
        "\n",
        "    ### Reshape to [N, N, out_size]\n",
        "    # coord_j[i, j] = coord[j] tiled (out_size // 2 - 1) times -> [N, N, out_size - 2]\n",
        "    coord_j = coord.unsqueeze(0).repeat(n_nodes, 1, out_size // 2 - 1)\n",
        "    # coord_i[i, j] = coord[i] -> [N, N, 2]\n",
        "    coord_i = coord.unsqueeze(1).repeat(1, n_nodes, 1)\n",
        "    return torch.cat((coord_j, coord_i), dim=2)\n",
        "\n",
        "\n",
        "class GraphNet(nn.Module):\n",
        "    \"\"\"Single-layer GNN whose adjacency matrix is predicted from pixel coordinates.\n",
        "\n",
        "    A small MLP (pred_edge_fc) maps the coordinate features of every node pair\n",
        "    to an edge weight in (-1, 1), yielding one learned [N, N] adjacency matrix A\n",
        "    shared by all images. Pixel intensities are aggregated as A @ x and passed\n",
        "    to a bias-free linear classifier.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, img_size=28, coord_features=4, num_classes=10):\n",
        "        super(GraphNet, self).__init__()\n",
        "\n",
        "        n_nodes = img_size**2\n",
        "        self.fc = nn.Linear(n_nodes, num_classes, bias=False)\n",
        "\n",
        "        # Registered as a buffer: moved by .to(device), but not a trainable parameter.\n",
        "        coord = make_coordinate_array(img_size, coord_features)\n",
        "        self.register_buffer('coord', coord)\n",
        "\n",
        "        ##########\n",
        "        # Edge Predictor\n",
        "        self.pred_edge_fc = nn.Sequential(nn.Linear(coord_features, 32), # coord -> hidden\n",
        "                                          nn.ReLU(),\n",
        "                                          nn.Linear(32, 1), # hidden -> edge\n",
        "                                          nn.Tanh())\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Classify a batch of images.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        x : torch.Tensor\n",
        "            Images of shape [B, C, H, W] with C*H*W == N (C == 1 and\n",
        "            H == W == img_size for MNIST).\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        (logits, probas) : tuple of torch.Tensor\n",
        "            Both of shape [B, num_classes]; probas is the softmax of logits.\n",
        "        \"\"\"\n",
        "        B = x.size(0)\n",
        "\n",
        "        ### Predict edges: [N, N, coord_features] -> [N, N, 1] -> [N, N]\n",
        "        # squeeze(-1) only drops the edge-score axis, so N == 1 still yields [1, 1].\n",
        "        # Kept as an attribute so the learned adjacency can be inspected after a forward pass.\n",
        "        self.A = self.pred_edge_fc(self.coord).squeeze(-1)\n",
        "\n",
        "        ### Broadcast the shared adjacency matrix over the batch\n",
        "        # [N, N] -> [1, N, N] -> [B, N, N]; expand returns a view, no copy is made.\n",
        "        A_tensor = self.A.unsqueeze(0).expand(B, -1, -1)\n",
        "\n",
        "        ### Reshape inputs\n",
        "        # [B, C, H, W] => [B, H*W, 1]\n",
        "        x_reshape = x.view(B, -1, 1)\n",
        "\n",
        "        # bmm = batch matrix product to sum the neighbor features\n",
        "        # Input: [B, N, N] x [B, N, 1]\n",
        "        # Output: [B, N]\n",
        "        avg_neighbor_features = torch.bmm(A_tensor, x_reshape).view(B, -1)\n",
        "\n",
        "        logits = self.fc(avg_neighbor_features)\n",
        "        probas = F.softmax(logits, dim=1)\n",
        "        return logits, probas"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UFc5IWyMg0Ax",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Seed right before building the model so its weight initialization is reproducible.\n",
        "torch.manual_seed(RANDOM_SEED)\n",
        "model = GraphNet(img_size=IMG_SIZE, num_classes=NUM_CLASSES).to(DEVICE)\n",
        "\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
      ],
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7-T2cJVdhtv4",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 762
        },
        "outputId": "dde4b58e-ad5a-4002-d862-a7c9a7d2b435"
      },
      "source": [
        "import hiddenlayer as hl\n",
        "\n",
        "# Trace the model with a dummy batch whose shape comes from the settings cell,\n",
        "# so the graph stays valid if BATCH_SIZE or IMG_SIZE are changed.\n",
        "hl.build_graph(model, torch.zeros([BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE], device=DEVICE))"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<hiddenlayer.graph.Graph at 0x7f862777bf60>"
            ],
            "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n -->\n<!-- Title: %3 Pages: 1 -->\n<svg width=\"2047pt\" height=\"540pt\"\n viewBox=\"0.00 0.00 2047.00 540.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 504)\">\n<title>%3</title>\n<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-72,36 -72,-504 1975,-504 1975,36 -72,36\"/>\n<!-- /outputs/7 -->\n<g id=\"node1\" class=\"node\">\n<title>/outputs/7</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"54,-225 0,-225 0,-189 54,-189 54,-225\"/>\n<text text-anchor=\"start\" x=\"15\" y=\"-204\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Shape</text>\n</g>\n<!-- /outputs/9 -->\n<g id=\"node3\" class=\"node\">\n<title>/outputs/9</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"145,-198 91,-198 91,-162 145,-162 145,-198\"/>\n<text text-anchor=\"start\" x=\"104\" y=\"-177\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Gather</text>\n</g>\n<!-- /outputs/7&#45;&gt;/outputs/9 -->\n<g id=\"edge1\" class=\"edge\">\n<title>/outputs/7&#45;&gt;/outputs/9</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M54.303,-198.8991C62.6713,-196.4162 72.0227,-193.6416 80.9279,-190.9994\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"82.1249,-194.2951 90.7163,-188.0952 80.1338,-187.5843 82.1249,-194.2951\"/>\n</g>\n<!-- /outputs/8 -->\n<g id=\"node2\" class=\"node\">\n<title>/outputs/8</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"54,-171 0,-171 0,-135 54,-135 54,-171\"/>\n<text text-anchor=\"start\" x=\"9\" y=\"-150\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Constant</text>\n</g>\n<!-- /outputs/8&#45;&gt;/outputs/9 
-->\n<g id=\"edge2\" class=\"edge\">\n<title>/outputs/8&#45;&gt;/outputs/9</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M54.303,-161.1009C62.6713,-163.5838 72.0227,-166.3584 80.9279,-169.0006\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"80.1338,-172.4157 90.7163,-171.9048 82.1249,-165.7049 80.1338,-172.4157\"/>\n</g>\n<!-- /outputs/21 -->\n<g id=\"node15\" class=\"node\">\n<title>/outputs/21</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"243,-198 182,-198 182,-162 243,-162 243,-198\"/>\n<text text-anchor=\"start\" x=\"190.5\" y=\"-177\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Unsqueeze</text>\n</g>\n<!-- /outputs/9&#45;&gt;/outputs/21 -->\n<g id=\"edge3\" class=\"edge\">\n<title>/outputs/9&#45;&gt;/outputs/21</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M145.0799,-180C153.3121,-180 162.5498,-180 171.4848,-180\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"171.7564,-183.5001 181.7563,-180 171.7563,-176.5001 171.7564,-183.5001\"/>\n</g>\n<!-- /outputs/36 -->\n<g id=\"node30\" class=\"node\">\n<title>/outputs/36</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"243,-360 182,-360 182,-324 243,-324 243,-360\"/>\n<text text-anchor=\"start\" x=\"190.5\" y=\"-339\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Unsqueeze</text>\n</g>\n<!-- /outputs/9&#45;&gt;/outputs/36 -->\n<g id=\"edge4\" class=\"edge\">\n<title>/outputs/9&#45;&gt;/outputs/36</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M138.3639,-198.2296C140.8087,-201.0055 143.1054,-203.9627 145,-207 171.854,-250.0497 153.85,-272.7863 182,-315 182.1913,-315.2868 182.3863,-315.5725 182.5849,-315.857\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"180.0539,-318.2866 189.1507,-323.7179 185.4264,-313.7992 180.0539,-318.2866\"/>\n</g>\n<!-- /outputs/43 -->\n<g id=\"node37\" class=\"node\">\n<title>/outputs/43</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"243,-144 182,-144 182,-108 243,-108 
243,-144\"/>\n<text text-anchor=\"start\" x=\"190.5\" y=\"-123\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Unsqueeze</text>\n</g>\n<!-- /outputs/9&#45;&gt;/outputs/43 -->\n<g id=\"edge5\" class=\"edge\">\n<title>/outputs/9&#45;&gt;/outputs/43</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M145.0799,-164.5258C153.7595,-159.566 163.5569,-153.9675 172.9387,-148.6064\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"174.8104,-151.5681 181.7563,-143.5678 171.3374,-145.4904 174.8104,-151.5681\"/>\n</g>\n<!-- /outputs/10 -->\n<g id=\"node4\" class=\"node\">\n<title>/outputs/10</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"241.5,-306 183.5,-306 183.5,-270 241.5,-270 241.5,-306\"/>\n<text text-anchor=\"start\" x=\"191.5\" y=\"-285\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Transpose</text>\n</g>\n<!-- /outputs/11 -->\n<g id=\"node5\" class=\"node\">\n<title>/outputs/11</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"355,-308 301,-308 301,-272 355,-272 355,-308\"/>\n<text text-anchor=\"start\" x=\"311\" y=\"-287\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n</g>\n<!-- /outputs/10&#45;&gt;/outputs/11 -->\n<g id=\"edge6\" class=\"edge\">\n<title>/outputs/10&#45;&gt;/outputs/11</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M241.6455,-288.5047C256.5565,-288.7629 274.8713,-289.08 290.7913,-289.3557\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"290.7883,-292.8561 300.8474,-289.5298 290.9095,-285.8571 290.7883,-292.8561\"/>\n<text text-anchor=\"middle\" x=\"272\" y=\"-292\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">4x32</text>\n</g>\n<!-- /outputs/12 -->\n<g id=\"node6\" class=\"node\">\n<title>/outputs/12</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"498,-319 444,-319 444,-283 498,-283 498,-319\"/>\n<text text-anchor=\"start\" x=\"462\" y=\"-298\" font-family=\"Times\" font-size=\"10.00\" 
fill=\"#000000\">Add</text>\n</g>\n<!-- /outputs/11&#45;&gt;/outputs/12 -->\n<g id=\"edge7\" class=\"edge\">\n<title>/outputs/11&#45;&gt;/outputs/12</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M355.2338,-292.0949C377.4983,-293.8076 409.2039,-296.2465 433.8066,-298.139\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"433.6286,-301.6355 443.8676,-298.9129 434.1655,-294.6561 433.6286,-301.6355\"/>\n<text text-anchor=\"middle\" x=\"399.5\" y=\"-300\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">784x784x32</text>\n</g>\n<!-- /outputs/13 -->\n<g id=\"node7\" class=\"node\">\n<title>/outputs/13</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"644,-326 590,-326 590,-290 644,-290 644,-326\"/>\n<text text-anchor=\"start\" x=\"607\" y=\"-305\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Relu</text>\n</g>\n<!-- /outputs/12&#45;&gt;/outputs/13 -->\n<g id=\"edge8\" class=\"edge\">\n<title>/outputs/12&#45;&gt;/outputs/13</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M498.1193,-302.3002C521.0227,-303.3983 554.0591,-304.9823 579.4958,-306.2019\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"579.3743,-309.7 589.5305,-306.683 579.7096,-302.708 579.3743,-309.7\"/>\n<text text-anchor=\"middle\" x=\"543\" y=\"-308\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">784x784x32</text>\n</g>\n<!-- /outputs/15 -->\n<g id=\"node9\" class=\"node\">\n<title>/outputs/15</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"806,-272 752,-272 752,-236 806,-236 806,-272\"/>\n<text text-anchor=\"start\" x=\"762\" y=\"-251\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n</g>\n<!-- /outputs/13&#45;&gt;/outputs/15 -->\n<g id=\"edge9\" class=\"edge\">\n<title>/outputs/13&#45;&gt;/outputs/15</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M644.1152,-298.9616C670.9823,-290.0059 712.3248,-276.2251 742.1579,-266.2807\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" 
points=\"743.4164,-269.5506 751.7964,-263.0679 741.2027,-262.9098 743.4164,-269.5506\"/>\n<text text-anchor=\"middle\" x=\"690.5\" y=\"-295\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">784x784x32</text>\n</g>\n<!-- /outputs/14 -->\n<g id=\"node8\" class=\"node\">\n<title>/outputs/14</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"646,-272 588,-272 588,-236 646,-236 646,-272\"/>\n<text text-anchor=\"start\" x=\"596\" y=\"-251\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Transpose</text>\n</g>\n<!-- /outputs/14&#45;&gt;/outputs/15 -->\n<g id=\"edge10\" class=\"edge\">\n<title>/outputs/14&#45;&gt;/outputs/15</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M646.337,-254C673.054,-254 712.6411,-254 741.6169,-254\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"741.7686,-257.5001 751.7685,-254 741.7685,-250.5001 741.7686,-257.5001\"/>\n<text text-anchor=\"middle\" x=\"690.5\" y=\"-257\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">32x1</text>\n</g>\n<!-- /outputs/16 -->\n<g id=\"node10\" class=\"node\">\n<title>/outputs/16</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"961,-272 907,-272 907,-236 961,-236 961,-272\"/>\n<text text-anchor=\"start\" x=\"925\" y=\"-251\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Add</text>\n</g>\n<!-- /outputs/15&#45;&gt;/outputs/16 -->\n<g id=\"edge11\" class=\"edge\">\n<title>/outputs/15&#45;&gt;/outputs/16</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M806.3541,-254C831.4769,-254 868.9202,-254 896.7638,-254\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"896.9194,-257.5001 906.9194,-254 896.9193,-250.5001 896.9194,-257.5001\"/>\n<text text-anchor=\"middle\" x=\"865\" y=\"-257\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">784x784x1</text>\n</g>\n<!-- /outputs/17 -->\n<g id=\"node11\" class=\"node\">\n<title>/outputs/17</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"1099,-272 1045,-272 
1045,-236 1099,-236 1099,-272\"/>\n<text text-anchor=\"start\" x=\"1062\" y=\"-251\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Tanh</text>\n</g>\n<!-- /outputs/16&#45;&gt;/outputs/17 -->\n<g id=\"edge12\" class=\"edge\">\n<title>/outputs/16&#45;&gt;/outputs/17</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M961.2644,-254C982.2445,-254 1011.4706,-254 1034.6299,-254\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1034.7923,-257.5001 1044.7923,-254 1034.7922,-250.5001 1034.7923,-257.5001\"/>\n<text text-anchor=\"middle\" x=\"1003\" y=\"-257\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">784x784x1</text>\n</g>\n<!-- /outputs/18 -->\n<g id=\"node12\" class=\"node\">\n<title>/outputs/18</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"1237,-260 1183,-260 1183,-224 1237,-224 1237,-260\"/>\n<text text-anchor=\"start\" x=\"1193\" y=\"-239\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Squeeze</text>\n</g>\n<!-- /outputs/17&#45;&gt;/outputs/18 -->\n<g id=\"edge13\" class=\"edge\">\n<title>/outputs/17&#45;&gt;/outputs/18</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1099.2644,-251.6292C1120.2445,-249.8048 1149.4706,-247.2634 1172.6299,-245.2496\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1173.1331,-248.7191 1182.7923,-244.3659 1172.5266,-241.7454 1173.1331,-248.7191\"/>\n<text text-anchor=\"middle\" x=\"1141\" y=\"-252\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">784x784x1</text>\n</g>\n<!-- /outputs/33 -->\n<g id=\"node27\" class=\"node\">\n<title>/outputs/33</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"1364,-231 1310,-231 1310,-195 1364,-195 1364,-231\"/>\n<text text-anchor=\"start\" x=\"1321\" y=\"-210\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Expand</text>\n</g>\n<!-- /outputs/18&#45;&gt;/outputs/33 -->\n<g id=\"edge14\" class=\"edge\">\n<title>/outputs/18&#45;&gt;/outputs/33</title>\n<path fill=\"none\" stroke=\"#000000\" 
d=\"M1237.2447,-235.7788C1255.5757,-231.593 1279.9686,-226.0229 1300.0446,-221.4386\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1300.8305,-224.8494 1309.8004,-219.2109 1299.2721,-218.025 1300.8305,-224.8494\"/>\n<text text-anchor=\"middle\" x=\"1273.5\" y=\"-233\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">784x784</text>\n</g>\n<!-- /outputs/19 -->\n<g id=\"node13\" class=\"node\">\n<title>/outputs/19</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"145,-90 91,-90 91,-54 145,-54 145,-90\"/>\n<text text-anchor=\"start\" x=\"100\" y=\"-69\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Constant</text>\n</g>\n<!-- /outputs/22 -->\n<g id=\"node16\" class=\"node\">\n<title>/outputs/22</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"243,-90 182,-90 182,-54 243,-54 243,-90\"/>\n<text text-anchor=\"start\" x=\"190.5\" y=\"-69\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Unsqueeze</text>\n</g>\n<!-- /outputs/19&#45;&gt;/outputs/22 -->\n<g id=\"edge15\" class=\"edge\">\n<title>/outputs/19&#45;&gt;/outputs/22</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M145.0799,-72C153.3121,-72 162.5498,-72 171.4848,-72\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"171.7564,-75.5001 181.7563,-72 171.7563,-68.5001 171.7564,-75.5001\"/>\n</g>\n<!-- /outputs/20 -->\n<g id=\"node14\" class=\"node\">\n<title>/outputs/20</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"145,-252 91,-252 91,-216 145,-216 145,-252\"/>\n<text text-anchor=\"start\" x=\"100\" y=\"-231\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Constant</text>\n</g>\n<!-- /outputs/23 -->\n<g id=\"node17\" class=\"node\">\n<title>/outputs/23</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"243,-252 182,-252 182,-216 243,-216 243,-252\"/>\n<text text-anchor=\"start\" x=\"190.5\" y=\"-231\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Unsqueeze</text>\n</g>\n<!-- 
/outputs/20&#45;&gt;/outputs/23 -->\n<g id=\"edge16\" class=\"edge\">\n<title>/outputs/20&#45;&gt;/outputs/23</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M145.0799,-234C153.3121,-234 162.5498,-234 171.4848,-234\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"171.7564,-237.5001 181.7563,-234 171.7563,-230.5001 171.7564,-237.5001\"/>\n</g>\n<!-- /outputs/24 -->\n<g id=\"node18\" class=\"node\">\n<title>/outputs/24</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"355,-198 301,-198 301,-162 355,-162 355,-198\"/>\n<text text-anchor=\"start\" x=\"313\" y=\"-177\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Concat</text>\n</g>\n<!-- /outputs/21&#45;&gt;/outputs/24 -->\n<g id=\"edge17\" class=\"edge\">\n<title>/outputs/21&#45;&gt;/outputs/24</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M243.1486,-180C257.7011,-180 275.2228,-180 290.5561,-180\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"290.7584,-183.5001 300.7584,-180 290.7583,-176.5001 290.7584,-183.5001\"/>\n</g>\n<!-- /outputs/22&#45;&gt;/outputs/24 -->\n<g id=\"edge18\" class=\"edge\">\n<title>/outputs/22&#45;&gt;/outputs/24</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M233.2638,-90.1982C236.5395,-93.1217 239.8751,-96.1297 243,-99 263.0487,-117.4152 285.3116,-138.6603 302.0225,-154.7713\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"299.6578,-157.3533 309.2818,-161.7837 304.5212,-152.3187 299.6578,-157.3533\"/>\n</g>\n<!-- /outputs/23&#45;&gt;/outputs/24 -->\n<g id=\"edge19\" class=\"edge\">\n<title>/outputs/23&#45;&gt;/outputs/24</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M243.1486,-219.6708C257.9809,-212.7362 275.8978,-204.3594 291.4382,-197.0938\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"293.1819,-200.1423 300.7584,-192.7363 290.2172,-193.8011 293.1819,-200.1423\"/>\n</g>\n<!-- /outputs/26 -->\n<g id=\"node20\" class=\"node\">\n<title>/outputs/26</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" 
points=\"498,-180 444,-180 444,-144 498,-144 498,-180\"/>\n<text text-anchor=\"start\" x=\"454\" y=\"-159\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Reshape</text>\n</g>\n<!-- /outputs/24&#45;&gt;/outputs/26 -->\n<g id=\"edge20\" class=\"edge\">\n<title>/outputs/24&#45;&gt;/outputs/26</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M355.2338,-176.572C377.4983,-173.7694 409.2039,-169.7785 433.8066,-166.6817\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"434.383,-170.1368 443.8676,-165.4153 433.5087,-163.1916 434.383,-170.1368\"/>\n</g>\n<!-- /outputs/25 -->\n<g id=\"node19\" class=\"node\">\n<title>/outputs/25</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"355,-144 301,-144 301,-108 355,-108 355,-144\"/>\n<text text-anchor=\"start\" x=\"310\" y=\"-123\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Constant</text>\n</g>\n<!-- /outputs/25&#45;&gt;/outputs/26 -->\n<g id=\"edge21\" class=\"edge\">\n<title>/outputs/25&#45;&gt;/outputs/26</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M355.2338,-132.8561C377.5977,-138.4861 409.4871,-146.5142 434.1357,-152.7195\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"433.3157,-156.1222 443.8676,-155.1695 435.0247,-149.334 433.3157,-156.1222\"/>\n</g>\n<!-- /outputs/27 -->\n<g id=\"node21\" class=\"node\">\n<title>/outputs/27</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"644,-180 590,-180 590,-144 644,-144 644,-180\"/>\n<text text-anchor=\"start\" x=\"605\" y=\"-159\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Shape</text>\n</g>\n<!-- /outputs/26&#45;&gt;/outputs/27 -->\n<g id=\"edge22\" class=\"edge\">\n<title>/outputs/26&#45;&gt;/outputs/27</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M498.1193,-162C521.0227,-162 554.0591,-162 579.4958,-162\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"579.5305,-165.5001 589.5305,-162 579.5304,-158.5001 579.5305,-165.5001\"/>\n</g>\n<!-- /outputs/31 -->\n<g 
id=\"node25\" class=\"node\">\n<title>/outputs/31</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"1099,-133 1045,-133 1045,-97 1099,-97 1099,-133\"/>\n<text text-anchor=\"start\" x=\"1060\" y=\"-112\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Equal</text>\n</g>\n<!-- /outputs/26&#45;&gt;/outputs/31 -->\n<g id=\"edge23\" class=\"edge\">\n<title>/outputs/26&#45;&gt;/outputs/31</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M484.9609,-143.6585C508.6768,-114.7893 559.8556,-62 617,-62 617,-62 617,-62 934,-62 970.388,-62 1008.7997,-78.2572 1035.8319,-92.7638\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1034.4545,-96.0021 1044.897,-97.7924 1037.8501,-89.8808 1034.4545,-96.0021\"/>\n</g>\n<!-- /outputs/32 -->\n<g id=\"node26\" class=\"node\">\n<title>/outputs/32</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"1237,-187 1183,-187 1183,-151 1237,-151 1237,-187\"/>\n<text text-anchor=\"start\" x=\"1197\" y=\"-166\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Where</text>\n</g>\n<!-- /outputs/26&#45;&gt;/outputs/32 -->\n<g id=\"edge24\" class=\"edge\">\n<title>/outputs/26&#45;&gt;/outputs/32</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M498.1843,-176.0515C526.9175,-189.5792 573.9476,-208 617,-208 617,-208 617,-208 1072,-208 1107.3669,-208 1145.9759,-195.9913 1173.3388,-185.3012\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1174.922,-188.4366 1182.8842,-181.4471 1172.3011,-181.9457 1174.922,-188.4366\"/>\n</g>\n<!-- /outputs/28 -->\n<g id=\"node22\" class=\"node\">\n<title>/outputs/28</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"823,-180 735,-180 735,-144 823,-144 823,-180\"/>\n<text text-anchor=\"start\" x=\"743\" y=\"-159\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">ConstantOfShape</text>\n</g>\n<!-- /outputs/27&#45;&gt;/outputs/28 -->\n<g id=\"edge25\" class=\"edge\">\n<title>/outputs/27&#45;&gt;/outputs/28</title>\n<path fill=\"none\" 
stroke=\"#000000\" d=\"M644.1152,-162C666.0793,-162 697.7176,-162 724.8331,-162\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"724.9744,-165.5001 734.9743,-162 724.9743,-158.5001 724.9744,-165.5001\"/>\n</g>\n<!-- /outputs/30 -->\n<g id=\"node24\" class=\"node\">\n<title>/outputs/30</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"961,-133 907,-133 907,-97 961,-97 961,-133\"/>\n<text text-anchor=\"start\" x=\"925\" y=\"-112\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Mul</text>\n</g>\n<!-- /outputs/28&#45;&gt;/outputs/30 -->\n<g id=\"edge26\" class=\"edge\">\n<title>/outputs/28&#45;&gt;/outputs/30</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M823.0022,-148.6574C846.4418,-141.5499 874.9436,-132.9074 897.1963,-126.1598\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"898.3861,-129.4565 906.9401,-123.2052 896.3548,-122.7577 898.3861,-129.4565\"/>\n</g>\n<!-- /outputs/28&#45;&gt;/outputs/32 -->\n<g id=\"edge27\" class=\"edge\">\n<title>/outputs/28&#45;&gt;/outputs/32</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M823.0991,-162.7162C908.2071,-164.0985 1093.6873,-167.1109 1172.9167,-168.3977\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1172.8716,-171.8974 1182.9271,-168.5603 1172.9853,-164.8983 1172.8716,-171.8974\"/>\n</g>\n<!-- /outputs/29 -->\n<g id=\"node23\" class=\"node\">\n<title>/outputs/29</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"806,-126 752,-126 752,-90 806,-90 806,-126\"/>\n<text text-anchor=\"start\" x=\"761\" y=\"-105\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Constant</text>\n</g>\n<!-- /outputs/29&#45;&gt;/outputs/30 -->\n<g id=\"edge28\" class=\"edge\">\n<title>/outputs/29&#45;&gt;/outputs/30</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M806.3541,-109.2353C831.4769,-110.3699 868.9202,-112.0609 896.7638,-113.3184\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"896.7716,-116.8222 906.9194,-113.777 897.0875,-109.8293 
896.7716,-116.8222\"/>\n</g>\n<!-- /outputs/30&#45;&gt;/outputs/31 -->\n<g id=\"edge29\" class=\"edge\">\n<title>/outputs/30&#45;&gt;/outputs/31</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M961.2644,-115C982.2445,-115 1011.4706,-115 1034.6299,-115\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1034.7923,-118.5001 1044.7923,-115 1034.7922,-111.5001 1034.7923,-118.5001\"/>\n</g>\n<!-- /outputs/31&#45;&gt;/outputs/32 -->\n<g id=\"edge30\" class=\"edge\">\n<title>/outputs/31&#45;&gt;/outputs/32</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1099.2961,-125.4741C1117.8458,-132.6147 1142.9284,-142.3141 1165,-151 1167.716,-152.0688 1170.5221,-153.1785 1173.3441,-154.2983\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1172.3066,-157.6524 1182.892,-158.1 1174.8961,-151.1489 1172.3066,-157.6524\"/>\n</g>\n<!-- /outputs/32&#45;&gt;/outputs/33 -->\n<g id=\"edge31\" class=\"edge\">\n<title>/outputs/32&#45;&gt;/outputs/33</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1237.2447,-178.4391C1255.6594,-184.819 1280.1915,-193.3183 1300.3193,-200.2917\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1299.2056,-203.6099 1309.8004,-203.5765 1301.4972,-196.9956 1299.2056,-203.6099\"/>\n</g>\n<!-- /outputs/41 -->\n<g id=\"node35\" class=\"node\">\n<title>/outputs/41</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"1512,-231 1458,-231 1458,-195 1512,-195 1512,-231\"/>\n<text text-anchor=\"start\" x=\"1468\" y=\"-210\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n</g>\n<!-- /outputs/33&#45;&gt;/outputs/41 -->\n<g id=\"edge32\" class=\"edge\">\n<title>/outputs/33&#45;&gt;/outputs/41</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1364.1455,-213C1387.5424,-213 1421.5551,-213 1447.5413,-213\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1447.7764,-216.5001 1457.7764,-213 1447.7764,-209.5001 1447.7764,-216.5001\"/>\n<text text-anchor=\"middle\" x=\"1411\" y=\"-216\" 
font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">128x784x784</text>\n</g>\n<!-- /outputs/34 -->\n<g id=\"node28\" class=\"node\">\n<title>/outputs/34</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"145,-414 91,-414 91,-378 145,-378 145,-414\"/>\n<text text-anchor=\"start\" x=\"100\" y=\"-393\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Constant</text>\n</g>\n<!-- /outputs/37 -->\n<g id=\"node31\" class=\"node\">\n<title>/outputs/37</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"243,-414 182,-414 182,-378 243,-378 243,-414\"/>\n<text text-anchor=\"start\" x=\"190.5\" y=\"-393\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Unsqueeze</text>\n</g>\n<!-- /outputs/34&#45;&gt;/outputs/37 -->\n<g id=\"edge33\" class=\"edge\">\n<title>/outputs/34&#45;&gt;/outputs/37</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M145.0799,-396C153.3121,-396 162.5498,-396 171.4848,-396\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"171.7564,-399.5001 181.7563,-396 171.7563,-392.5001 171.7564,-399.5001\"/>\n</g>\n<!-- /outputs/35 -->\n<g id=\"node29\" class=\"node\">\n<title>/outputs/35</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"145,-468 91,-468 91,-432 145,-432 145,-468\"/>\n<text text-anchor=\"start\" x=\"100\" y=\"-447\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Constant</text>\n</g>\n<!-- /outputs/38 -->\n<g id=\"node32\" class=\"node\">\n<title>/outputs/38</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"243,-468 182,-468 182,-432 243,-432 243,-468\"/>\n<text text-anchor=\"start\" x=\"190.5\" y=\"-447\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Unsqueeze</text>\n</g>\n<!-- /outputs/35&#45;&gt;/outputs/38 -->\n<g id=\"edge34\" class=\"edge\">\n<title>/outputs/35&#45;&gt;/outputs/38</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M145.0799,-450C153.3121,-450 162.5498,-450 171.4848,-450\"/>\n<polygon fill=\"#000000\" 
stroke=\"#000000\" points=\"171.7564,-453.5001 181.7563,-450 171.7563,-446.5001 171.7564,-453.5001\"/>\n</g>\n<!-- /outputs/39 -->\n<g id=\"node33\" class=\"node\">\n<title>/outputs/39</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"355,-395 301,-395 301,-359 355,-359 355,-395\"/>\n<text text-anchor=\"start\" x=\"313\" y=\"-374\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Concat</text>\n</g>\n<!-- /outputs/36&#45;&gt;/outputs/39 -->\n<g id=\"edge35\" class=\"edge\">\n<title>/outputs/36&#45;&gt;/outputs/39</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M243.1486,-351.2875C257.841,-355.7397 275.5601,-361.1091 290.9978,-365.7872\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"290.1731,-369.1944 300.7584,-368.745 292.2032,-362.4952 290.1731,-369.1944\"/>\n</g>\n<!-- /outputs/37&#45;&gt;/outputs/39 -->\n<g id=\"edge36\" class=\"edge\">\n<title>/outputs/37&#45;&gt;/outputs/39</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M243.1486,-390.9582C257.7011,-388.5643 275.2228,-385.682 290.5561,-383.1596\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"291.4592,-386.5582 300.7584,-381.4813 290.3229,-379.651 291.4592,-386.5582\"/>\n</g>\n<!-- /outputs/38&#45;&gt;/outputs/39 -->\n<g id=\"edge37\" class=\"edge\">\n<title>/outputs/38&#45;&gt;/outputs/39</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M241.0506,-431.9551C256.4945,-422.194 275.7192,-410.0433 292.1615,-399.6512\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"294.385,-402.3863 300.9682,-394.085 290.6451,-396.4691 294.385,-402.3863\"/>\n</g>\n<!-- /outputs/40 -->\n<g id=\"node34\" class=\"node\">\n<title>/outputs/40</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"570,-381 516,-381 516,-345 570,-345 570,-381\"/>\n<text text-anchor=\"start\" x=\"526\" y=\"-360\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Reshape</text>\n</g>\n<!-- /outputs/39&#45;&gt;/outputs/40 -->\n<g id=\"edge38\" 
class=\"edge\">\n<title>/outputs/39&#45;&gt;/outputs/40</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M355.2299,-375.2269C393.4359,-372.7391 463.07,-368.2047 505.9601,-365.4119\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"506.2398,-368.9012 515.9912,-364.7587 505.7849,-361.916 506.2398,-368.9012\"/>\n</g>\n<!-- /outputs/40&#45;&gt;/outputs/41 -->\n<g id=\"edge39\" class=\"edge\">\n<title>/outputs/40&#45;&gt;/outputs/41</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M570.184,-362.4108C599.7782,-361.8208 648.4996,-361 690.5,-361 690.5,-361 690.5,-361 1337,-361 1404.5649,-361 1452.0582,-282.9243 1472.9936,-240.1886\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1476.1587,-241.6825 1477.2851,-231.1477 1469.835,-238.6807 1476.1587,-241.6825\"/>\n<text text-anchor=\"middle\" x=\"1003\" y=\"-364\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">128x784x1</text>\n</g>\n<!-- /outputs/46 -->\n<g id=\"node40\" class=\"node\">\n<title>/outputs/46</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"1652,-108 1598,-108 1598,-72 1652,-72 1652,-108\"/>\n<text text-anchor=\"start\" x=\"1608\" y=\"-87\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Reshape</text>\n</g>\n<!-- /outputs/41&#45;&gt;/outputs/46 -->\n<g id=\"edge40\" class=\"edge\">\n<title>/outputs/41&#45;&gt;/outputs/46</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1505.6575,-194.8509C1529.7487,-173.6851 1569.7047,-138.5808 1596.6774,-114.8834\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1599.0554,-117.4532 1604.2578,-108.2235 1594.4352,-112.1944 1599.0554,-117.4532\"/>\n<text text-anchor=\"middle\" x=\"1554\" y=\"-176\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">128x784x1</text>\n</g>\n<!-- /outputs/42 -->\n<g id=\"node36\" class=\"node\">\n<title>/outputs/42</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"145,-36 91,-36 91,0 145,0 145,-36\"/>\n<text text-anchor=\"start\" x=\"100\" y=\"-15\" 
font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Constant</text>\n</g>\n<!-- /outputs/44 -->\n<g id=\"node38\" class=\"node\">\n<title>/outputs/44</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"243,-36 182,-36 182,0 243,0 243,-36\"/>\n<text text-anchor=\"start\" x=\"190.5\" y=\"-15\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Unsqueeze</text>\n</g>\n<!-- /outputs/42&#45;&gt;/outputs/44 -->\n<g id=\"edge41\" class=\"edge\">\n<title>/outputs/42&#45;&gt;/outputs/44</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M145.0799,-18C153.3121,-18 162.5498,-18 171.4848,-18\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"171.7564,-21.5001 181.7563,-18 171.7563,-14.5001 171.7564,-21.5001\"/>\n</g>\n<!-- /outputs/45 -->\n<g id=\"node39\" class=\"node\">\n<title>/outputs/45</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"355,-42 301,-42 301,-6 355,-6 355,-42\"/>\n<text text-anchor=\"start\" x=\"313\" y=\"-21\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Concat</text>\n</g>\n<!-- /outputs/43&#45;&gt;/outputs/45 -->\n<g id=\"edge42\" class=\"edge\">\n<title>/outputs/43&#45;&gt;/outputs/45</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M232.8832,-107.9992C251.4889,-91.5682 279.1631,-67.1287 299.8703,-48.8418\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"302.2533,-51.4068 307.4321,-42.1639 297.6197,-46.1599 302.2533,-51.4068\"/>\n</g>\n<!-- /outputs/44&#45;&gt;/outputs/45 -->\n<g id=\"edge43\" class=\"edge\">\n<title>/outputs/44&#45;&gt;/outputs/45</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M243.1486,-19.5921C257.7011,-20.3481 275.2228,-21.2583 290.5561,-22.0549\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"290.5902,-25.5613 300.7584,-22.5849 290.9535,-18.5707 290.5902,-25.5613\"/>\n</g>\n<!-- /outputs/45&#45;&gt;/outputs/46 -->\n<g id=\"edge44\" class=\"edge\">\n<title>/outputs/45&#45;&gt;/outputs/46</title>\n<path fill=\"none\" stroke=\"#000000\" 
d=\"M355.148,-24C383.9122,-24 430.6571,-24 471,-24 471,-24 471,-24 1485,-24 1525.7348,-24 1567.3183,-47.1434 1594.4614,-66.0208\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1592.5818,-68.9799 1602.7474,-71.9652 1596.6623,-63.2921 1592.5818,-68.9799\"/>\n</g>\n<!-- /outputs/48 -->\n<g id=\"node42\" class=\"node\">\n<title>/outputs/48</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"1781,-71 1727,-71 1727,-35 1781,-35 1781,-71\"/>\n<text text-anchor=\"start\" x=\"1737\" y=\"-50\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n</g>\n<!-- /outputs/46&#45;&gt;/outputs/48 -->\n<g id=\"edge45\" class=\"edge\">\n<title>/outputs/46&#45;&gt;/outputs/48</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1652.0426,-82.2436C1670.963,-76.8168 1696.4641,-69.5025 1717.2292,-63.5467\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1718.3455,-66.8677 1726.993,-60.7462 1716.4156,-60.139 1718.3455,-66.8677\"/>\n<text text-anchor=\"middle\" x=\"1690.5\" y=\"-78\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">128x784</text>\n</g>\n<!-- /outputs/47 -->\n<g id=\"node41\" class=\"node\">\n<title>/outputs/47</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"1654,-54 1596,-54 1596,-18 1654,-18 1654,-54\"/>\n<text text-anchor=\"start\" x=\"1604\" y=\"-33\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Transpose</text>\n</g>\n<!-- /outputs/47&#45;&gt;/outputs/48 -->\n<g id=\"edge46\" class=\"edge\">\n<title>/outputs/47&#45;&gt;/outputs/48</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1654.1582,-39.8013C1670.3193,-41.9125 1690.7736,-44.5915 1709,-47 1711.5297,-47.3343 1714.1392,-47.68 1716.768,-48.0288\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1716.4997,-51.5239 1726.8739,-49.3727 1717.4225,-44.585 1716.4997,-51.5239\"/>\n<text text-anchor=\"middle\" x=\"1690.5\" y=\"-50\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">784x10</text>\n</g>\n<!-- 
/outputs/49 -->\n<g id=\"node43\" class=\"node\">\n<title>/outputs/49</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"1903,-71 1849,-71 1849,-35 1903,-35 1903,-71\"/>\n<text text-anchor=\"start\" x=\"1859\" y=\"-50\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Softmax</text>\n</g>\n<!-- /outputs/48&#45;&gt;/outputs/49 -->\n<g id=\"edge47\" class=\"edge\">\n<title>/outputs/48&#45;&gt;/outputs/49</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1781.0758,-53C1798.0553,-53 1820.1767,-53 1838.7924,-53\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1839,-56.5001 1848.9999,-53 1838.9999,-49.5001 1839,-56.5001\"/>\n<text text-anchor=\"middle\" x=\"1815\" y=\"-56\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">128x10</text>\n</g>\n</g>\n</svg>\n"
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q6Y7VQDug0A0",
        "colab_type": "text"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wkjaWvzpg0A0",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 639
        },
        "outputId": "d4e6052d-bd2b-4c60-f4e8-23b0ff1cf098"
      },
      "source": [
        "def compute_acc(model, data_loader, device):\n",
        "    \"\"\"Return the classification accuracy (in percent) of `model` on `data_loader`.\n",
        "\n",
        "    `model(features)` must return a `(logits, probas)` pair; the predicted\n",
        "    label is the arg-max of `probas` along dim 1. Switching the model to\n",
        "    eval mode and disabling autograd is left to the caller.\n",
        "\n",
        "    Returns a 0-d float tensor on `device`. Raises ValueError when\n",
        "    `data_loader` yields no examples.\n",
        "    \"\"\"\n",
        "    correct_pred, num_examples = 0, 0\n",
        "    for features, targets in data_loader:\n",
        "        features = features.to(device)\n",
        "        targets = targets.to(device)\n",
        "        logits, probas = model(features)\n",
        "        _, predicted_labels = torch.max(probas, 1)\n",
        "        num_examples += targets.size(0)\n",
        "        correct_pred += (predicted_labels == targets).sum()\n",
        "    if num_examples == 0:\n",
        "        # without this guard `0.float()` fails with an opaque AttributeError\n",
        "        raise ValueError('data_loader yielded no examples')\n",
        "    return correct_pred.float()/num_examples * 100\n",
        "\n",
        "\n",
        "start_time = time.time()\n",
        "\n",
        "cost_list = []\n",
        "# per-epoch accuracies are stored as plain Python floats rather than\n",
        "# 0-d tensors on DEVICE, so the plotting cells can pass them straight\n",
        "# to matplotlib (NumPy cannot convert CUDA tensors implicitly)\n",
        "train_acc_list, valid_acc_list = [], []\n",
        "\n",
        "\n",
        "for epoch in range(NUM_EPOCHS):\n",
        "\n",
        "    model.train()\n",
        "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
        "\n",
        "        features = features.to(DEVICE)\n",
        "        targets = targets.to(DEVICE)\n",
        "\n",
        "        ### FORWARD AND BACK PROP\n",
        "        logits, probas = model(features)\n",
        "        cost = F.cross_entropy(logits, targets)\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        cost.backward()\n",
        "\n",
        "        ### UPDATE MODEL PARAMETERS\n",
        "        optimizer.step()\n",
        "\n",
        "        #################################################\n",
        "        ### CODE ONLY FOR LOGGING BEYOND THIS POINT\n",
        "        #################################################\n",
        "        cost_list.append(cost.item())\n",
        "        if not batch_idx % 150:\n",
        "            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
        "                  f'Batch {batch_idx:03d}/{len(train_loader):03d} |'\n",
        "                  f' Cost: {cost:.4f}')\n",
        "\n",
        "    model.eval()\n",
        "    with torch.set_grad_enabled(False): # save memory during inference\n",
        "\n",
        "        train_acc = float(compute_acc(model, train_loader, device=DEVICE))\n",
        "        valid_acc = float(compute_acc(model, valid_loader, device=DEVICE))\n",
        "\n",
        "        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\\n'\n",
        "              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')\n",
        "\n",
        "        train_acc_list.append(train_acc)\n",
        "        valid_acc_list.append(valid_acc)\n",
        "\n",
        "    elapsed = (time.time() - start_time)/60\n",
        "    print(f'Time elapsed: {elapsed:.2f} min')\n",
        "\n",
        "elapsed = (time.time() - start_time)/60\n",
        "print(f'Total Training Time: {elapsed:.2f} min')"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 001/050 | Batch 000/461 | Cost: 24.2727\n",
            "Epoch: 001/050 | Batch 150/461 | Cost: 2.2706\n",
            "Epoch: 001/050 | Batch 300/461 | Cost: 1.8713\n",
            "Epoch: 001/050 | Batch 450/461 | Cost: 1.5048\n",
            "Epoch: 001/050\n",
            "Train ACC: 50.39 | Validation ACC: 54.80\n",
            "Time elapsed: 3.08 min\n",
            "Epoch: 002/050 | Batch 000/461 | Cost: 1.4445\n",
            "Epoch: 002/050 | Batch 150/461 | Cost: 1.3288\n",
            "Epoch: 002/050 | Batch 300/461 | Cost: 1.1868\n",
            "Epoch: 002/050 | Batch 450/461 | Cost: 1.2040\n",
            "Epoch: 002/050\n",
            "Train ACC: 67.68 | Validation ACC: 71.40\n",
            "Time elapsed: 6.16 min\n",
            "Epoch: 003/050 | Batch 000/461 | Cost: 1.2128\n",
            "Epoch: 003/050 | Batch 150/461 | Cost: 0.9953\n",
            "Epoch: 003/050 | Batch 300/461 | Cost: 0.9818\n",
            "Epoch: 003/050 | Batch 450/461 | Cost: 1.0487\n",
            "Epoch: 003/050\n",
            "Train ACC: 68.09 | Validation ACC: 73.40\n",
            "Time elapsed: 9.23 min\n",
            "Epoch: 004/050 | Batch 000/461 | Cost: 1.0444\n",
            "Epoch: 004/050 | Batch 150/461 | Cost: 0.9064\n",
            "Epoch: 004/050 | Batch 300/461 | Cost: 0.9152\n",
            "Epoch: 004/050 | Batch 450/461 | Cost: 0.7396\n",
            "Epoch: 004/050\n",
            "Train ACC: 76.10 | Validation ACC: 80.20\n",
            "Time elapsed: 12.30 min\n",
            "Epoch: 005/050 | Batch 000/461 | Cost: 0.7698\n",
            "Epoch: 005/050 | Batch 150/461 | Cost: 0.8356\n",
            "Epoch: 005/050 | Batch 300/461 | Cost: 0.6544\n",
            "Epoch: 005/050 | Batch 450/461 | Cost: 0.8700\n",
            "Epoch: 005/050\n",
            "Train ACC: 80.30 | Validation ACC: 84.10\n",
            "Time elapsed: 15.39 min\n",
            "Epoch: 006/050 | Batch 000/461 | Cost: 0.6292\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yc5pk88Cg0A2",
        "colab_type": "text"
      },
      "source": [
        "## Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PZNRwHORg0A3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# last adjacency matrix (model.A); detach() in case the most recent forward\n",
        "# pass ran with autograd enabled, e.g. after an interrupted training run\n",
        "\n",
        "plt.imshow(model.A.detach().cpu());"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7S4ZDfUEg0A5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# raw per-minibatch loss plus a 200-minibatch moving average\n",
        "running_avg_window = 200\n",
        "running_avg = np.convolve(cost_list,\n",
        "                          np.ones(running_avg_window,)/running_avg_window,\n",
        "                          mode='valid')\n",
        "\n",
        "plt.plot(cost_list, label='Minibatch cost')\n",
        "plt.plot(running_avg, label='Running average')\n",
        "plt.xlabel('Iteration')\n",
        "plt.ylabel('Cross Entropy')\n",
        "plt.legend()\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5DPCOTp3g0A8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# x-axis comes from the number of recorded epochs (not NUM_EPOCHS) so the\n",
        "# plot still works if training was stopped early; float() moves each value\n",
        "# to the host in case it is a 0-d tensor on DEVICE (compute_acc's return type)\n",
        "epoch_axis = np.arange(1, len(train_acc_list)+1)\n",
        "plt.plot(epoch_axis, [float(acc) for acc in train_acc_list], label='Training')\n",
        "plt.plot(epoch_axis, [float(acc) for acc in valid_acc_list], label='Validation')\n",
        "\n",
        "plt.xlabel('Epoch')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.legend()\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "J4D8ZsNug0A_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# compute_acc does not change the model's mode; switch to eval mode here so\n",
        "# the result is correct even if the training cell was interrupted mid-epoch\n",
        "model.eval()\n",
        "with torch.set_grad_enabled(False):\n",
        "    test_acc = compute_acc(model=model,\n",
        "                           data_loader=test_loader,\n",
        "                           device=DEVICE)\n",
        "\n",
        "    valid_acc = compute_acc(model=model,\n",
        "                            data_loader=valid_loader,\n",
        "                            device=DEVICE)\n",
        "\n",
        "print(f'Validation ACC: {valid_acc:.2f}%')\n",
        "print(f'Test ACC: {test_acc:.2f}%')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jWaaObAEg0BB",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%watermark -iv"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}